package org.zjvis.datascience.spark.algorithm;


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @description Spark anomaly-detection operator [based on statistical hypothesis testing, 3-sigma rule]
 * @date 2021-12-23
 */
public class StatisticsAnomaly extends BaseAlgorithm {
    private String[] featureCols;

    private String targetTable;

    private String sourceTable;

    private static String sql_tpl = "select a.*, (CASE WHEN (a.`%s` > b.avgvalue + 3 * b.stdvalue or a.`%s` < b.avgvalue - 3 * b.stdvalue) THEN '-1' ELSE '1' END) as label " +
            "from (select stddev(`%s`) as stdvalue, avg(`%s`) as avgvalue from %s) as b, %s as a";

    public StatisticsAnomaly(SparkSession sparkSession) {
        super(sparkSession);
    }

    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }

        return true;
    }

    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        if (isHive) {
            String sql = String.format(sql_tpl, featureCols[0], featureCols[0], featureCols[0], featureCols[0], sourceTable, sourceTable);
            dataset = sparkSession.sql(sql).persist(StorageLevel.MEMORY_AND_DISK());
        } else {
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
                String sql = String.format(sql_tpl, featureCols[0], featureCols[0], featureCols[0], featureCols[0], cacheTable, cacheTable);
                dataset = sparkSession.sql(sql).persist(StorageLevel.MEMORY_AND_DISK());
            } else {
                Dataset<Row> tmpDF = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol);
                String tmpView = String.format("tmpView_%s", uniqueKey);
                tmpDF.createOrReplaceTempView(tmpView);
                String sql = String.format(sql_tpl, featureCols[0], featureCols[0], featureCols[0], featureCols[0], tmpView, tmpView);
                dataset = sparkSession.sql(sql).persist(StorageLevel.MEMORY_AND_DISK());
            }
        }

        if (isHive) {
            this.saveHiveTable(dataset, targetTable);
        } else {
            UtilTool.saveGreenplumTable(dataset, targetTable);
        }

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        dataset.unpersist();
        return true;
    }
}
